{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Effective TensorFlow 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/effective_tf2\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/effective_tf2.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/effective_tf2.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/effective_tf2.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This guide provides a list of best practices for writing code using TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the [migrate section of the guide](https://tensorflow.org/guide/migrate) for more info on migrating your TF1 code to TF2."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Import TensorFlow and other dependencies for the examples in this guide."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqR2PQG4ZaZ0"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ngds9zateIY8"
      },
      "source": [
        "## Recommendations for idiomatic TensorFlow 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B3RdHaroMAi4"
      },
      "source": [
        "### Refactor your code into smaller modules\n",
        "\n",
        "A good practice is to refactor your code into smaller functions that are called as needed. For best performance, decorate the largest blocks of computation that you can in a `tf.function`. Nested Python functions called by a `tf.function` do not require their own separate decorations, unless you want to use different `jit_compile` settings for them. Depending on your use case, the decorated block could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass."
      ]
    },
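    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of this advice, the outermost function below carries the `tf.function` decorator while the helpers it calls stay plain Python. The names `dense`, `two_layer`, and `forward` are hypothetical and chosen only for illustration.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Plain Python helpers: no decorator is needed, because they are traced\n",
        "# as part of whichever tf.function calls them.\n",
        "def dense(x, w, b):\n",
        "  return tf.matmul(x, w) + b\n",
        "\n",
        "def two_layer(x, w1, b1, w2, b2):\n",
        "  hidden = tf.nn.relu(dense(x, w1, b1))\n",
        "  return dense(hidden, w2, b2)\n",
        "\n",
        "# Decorate only the outermost block of computation.\n",
        "@tf.function\n",
        "def forward(x, w1, b1, w2, b2):\n",
        "  return two_layer(x, w1, b1, w2, b2)\n",
        "\n",
        "x = tf.ones([4, 3])\n",
        "w1, b1 = tf.ones([3, 5]), tf.zeros([5])\n",
        "w2, b2 = tf.ones([5, 2]), tf.zeros([2])\n",
        "print(forward(x, w1, b1, w2, b2).shape)  # (4, 2)\n",
        "```"
      ]
    },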
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rua1l8et3Evd"
      },
      "source": [
        "### Adjust the default learning rate for some `tf.keras.optimizer`s\n",
        "<a name=\"optimizer_defaults\"></a>\n",
        "\n",
        "Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.\n",
        "\n",
        "There are no changes for `optimizers.SGD`, `optimizers.Adam`, or `optimizers.RMSprop`.\n",
        "\n",
        "The following default learning rates have changed:\n",
        "\n",
        "- `optimizers.Adagrad` from `0.01` to `0.001`\n",
        "- `optimizers.Adadelta` from `1.0` to `0.001`\n",
        "- `optimizers.Adamax` from `0.002` to `0.001`\n",
        "- `optimizers.Nadam` from `0.002` to `0.001`"
      ]
    },
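    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If a model that was tuned against the earlier defaults converges differently, one option is to pass the learning rate explicitly instead of relying on the default. The sketch below uses the earlier values from the list above; whether those values suit your model is an assumption you should verify.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Pass learning_rate explicitly so the value does not depend on the default.\n",
        "adagrad = tf.keras.optimizers.Adagrad(learning_rate=0.01)\n",
        "adadelta = tf.keras.optimizers.Adadelta(learning_rate=1.0)\n",
        "```"
      ]
    },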
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z6LfkpsEldEV"
      },
      "source": [
        "### Use `tf.Module`s and Keras layers to manage variables\n",
        "\n",
        "`tf.Module`s and `tf.keras.layers.Layer`s offer the convenient `variables` and\n",
        "`trainable_variables` properties, which recursively gather up all dependent\n",
        "variables. This makes it easy to manage variables locally to where they are\n",
        "being used."
      ]
    },
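    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following minimal sketch shows the recursive collection in action. The `Dense` and `MLP` classes are hypothetical examples written for this illustration, not part of the TensorFlow API.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "class Dense(tf.Module):\n",
        "  def __init__(self, in_features, out_features, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')\n",
        "    self.b = tf.Variable(tf.zeros([out_features]), name='b')\n",
        "\n",
        "  def __call__(self, x):\n",
        "    return tf.nn.relu(tf.matmul(x, self.w) + self.b)\n",
        "\n",
        "class MLP(tf.Module):\n",
        "  def __init__(self, name=None):\n",
        "    super().__init__(name=name)\n",
        "    self.hidden = Dense(3, 4)\n",
        "    self.out = Dense(4, 2)\n",
        "    # Non-trainable state is tracked in `variables` but not in\n",
        "    # `trainable_variables`.\n",
        "    self.step = tf.Variable(0, trainable=False, name='step')\n",
        "\n",
        "  def __call__(self, x):\n",
        "    return self.out(self.hidden(x))\n",
        "\n",
        "mlp = MLP()\n",
        "print(len(mlp.variables))            # 5\n",
        "print(len(mlp.trainable_variables))  # 4\n",
        "```\n",
        "\n",
        "The variables of the nested `Dense` modules are found through the attributes of `MLP`, so no global collection is needed."
      ]
    },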
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WQ2U0rj1oBlc"
      },
      "source": [
        "Keras layers/models inherit from `tf.train.Checkpointable` and are integrated\n",
        "with `@tf.function`, which makes it possible to directly checkpoint or export\n",
        "SavedModels from Keras objects. You do not necessarily have to use Keras'\n",
        "`Model.fit` API to take advantage of these integrations.\n",
        "\n",
        "Read the section on [transfer learning and fine-tuning](https://www.tensorflow.org/guide/keras/transfer_learning#transfer_learning_fine-tuning_with_a_custom_training_loop) in the Keras guide to learn how to collect a subset of relevant variables using Keras."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j34MsfxWodG6"
      },
      "source": [
        "### Combine `tf.data.Dataset`s and `tf.function`\n",
        "\n",
        "The [TensorFlow Datasets](https://tensorflow.org/datasets) package (`tfds`) contains utilities for loading predefined datasets as `tf.data.Dataset` objects. For this example, you can load the MNIST dataset using `tfds`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BMgxaLH74_s-"
      },
      "outputs": [],
      "source": [
        "datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)\n",
        "mnist_train, mnist_test = datasets['train'], datasets['test']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hPJhEuvj5VfR"
      },
      "source": [
        "Then prepare the data for training:\n",
        "\n",
        "  - Re-scale each image.\n",
        "  - Shuffle the order of the examples.\n",
        "  - Collect batches of images and labels.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "StBRHtJM2S7o"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = 10 # Use a much larger value for real code\n",
        "BATCH_SIZE = 64\n",
        "NUM_EPOCHS = 5\n",
        "\n",
        "\n",
        "def scale(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image /= 255\n",
        "\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKq14zKKFAdv"
      },
      "source": [
        "To keep the example short, trim the dataset to only return 5 batches:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_J-o4YjG2mkM"
      },
      "outputs": [],
      "source": [
        "train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "test_data = mnist_test.map(scale).batch(BATCH_SIZE)\n",
        "\n",
        "STEPS_PER_EPOCH = 5\n",
        "\n",
        "train_data = train_data.take(STEPS_PER_EPOCH)\n",
        "test_data = test_data.take(STEPS_PER_EPOCH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XEqdkH54VM6c"
      },
      "outputs": [],
      "source": [
        "image_batch, label_batch = next(iter(train_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "loTPH2Pz4_Oj"
      },
      "source": [
        "Use regular Python iteration to iterate over training data that fits in memory. Otherwise, `tf.data.Dataset` is the best way to stream training data from disk. Datasets are [iterables (not iterators)](https://docs.python.org/3/glossary.html#term-iterable), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in `tf.function`, which replaces Python iteration with the equivalent graph operations using AutoGraph.\n",
        "\n",
        "```python\n",
        "@tf.function\n",
        "def train(model, dataset, optimizer):\n",
        "  for x, y in dataset:\n",
        "    with tf.GradientTape() as tape:\n",
        "      # training=True is only needed if there are layers with different\n",
        "      # behavior during training versus inference (e.g. Dropout).\n",
        "      prediction = model(x, training=True)\n",
        "      loss = loss_fn(y, prediction)\n",
        "    gradients = tape.gradient(loss, model.trainable_variables)\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "```\n",
        "\n",
        "If you use the Keras `Model.fit` API, you won't have to worry about dataset\n",
        "iteration.\n",
        "\n",
        "```python\n",
        "model.compile(optimizer=optimizer, loss=loss_fn)\n",
        "model.fit(dataset)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mSev7vZC5GJB"
      },
      "source": [
        "<a name=\"keras_training_loops\"></a>\n",
        "### Use Keras training loops\n",
        "\n",
        "If you don't need low-level control of your training process, using Keras' built-in `fit`, `evaluate`, and `predict` methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential,  functional, or sub-classed).\n",
        "\n",
        "The advantages of these methods include:\n",
        "\n",
        "- They accept NumPy arrays, Python generators, and `tf.data.Dataset`s.\n",
        "- They apply regularization and activation losses automatically.\n",
        "- They support `tf.distribute` where the training code remains the same [regardless of the hardware configuration](distributed_training.ipynb).\n",
        "- They support arbitrary callables as losses and metrics.\n",
        "- They support callbacks like `tf.keras.callbacks.TensorBoard`, and custom callbacks.\n",
        "- They are performant, automatically using TensorFlow graphs.\n",
        "\n",
        "Here is an example of training a model using a `Dataset`. For details on how this works, check out the [tutorials](https://tensorflow.org/tutorials)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uzHFCzd45Rae"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Conv2D(32, 3, activation='relu',\n",
        "                           kernel_regularizer=tf.keras.regularizers.l2(0.02),\n",
        "                           input_shape=(28, 28, 1)),\n",
        "    tf.keras.layers.MaxPooling2D(),\n",
        "    tf.keras.layers.Flatten(),\n",
        "    tf.keras.layers.Dropout(0.1),\n",
        "    tf.keras.layers.Dense(64, activation='relu'),\n",
        "    tf.keras.layers.BatchNormalization(),\n",
        "    tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "# Model is the full model w/o custom layers\n",
        "model.compile(optimizer='adam',\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "model.fit(train_data, epochs=NUM_EPOCHS)\n",
        "loss, acc = model.evaluate(test_data)\n",
        "\n",
        "print(\"Loss {}, Accuracy {}\".format(loss, acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LQTaHTuK5S5A"
      },
      "source": [
        "<a name=\"custom_loop\"></a>\n",
        "\n",
        "### Customize training and write your own loop\n",
        "\n",
        "If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on [customizing `fit`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) to learn more.\n",
        "\n",
        "You can also implement many things as a `tf.keras.callbacks.Callback`.\n",
        "\n",
        "This method has many of the advantages [mentioned previously](#keras_training_loops), but gives you control of the train step and even the outer loop.\n",
        "\n",
        "There are three steps to a standard training loop:\n",
        "\n",
        "1. Iterate over a Python generator or `tf.data.Dataset` to get batches of examples.\n",
        "2. Use `tf.GradientTape` to collect gradients.\n",
        "3. Use one of the `tf.keras.optimizers` to apply weight updates to the model's variables.\n",
        "\n",
        "Remember:\n",
        "\n",
        "- Always include a `training` argument on the `call` method of subclassed layers and models.\n",
        "- Make sure to call the model with the `training` argument set correctly.\n",
        "- Depending on usage, model variables may not exist until the model is run on a batch of data.\n",
        "- You need to manually handle things like regularization losses for the model.\n",
        "\n",
        "There is no need to run variable initializers or to add manual control dependencies. `tf.function` handles automatic control dependencies and variable initialization on creation for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gQooejfYlQeF"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Conv2D(32, 3, activation='relu',\n",
        "                           kernel_regularizer=tf.keras.regularizers.l2(0.02),\n",
        "                           input_shape=(28, 28, 1)),\n",
        "    tf.keras.layers.MaxPooling2D(),\n",
        "    tf.keras.layers.Flatten(),\n",
        "    tf.keras.layers.Dropout(0.1),\n",
        "    tf.keras.layers.Dense(64, activation='relu'),\n",
        "    tf.keras.layers.BatchNormalization(),\n",
        "    tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(0.001)\n",
        "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "@tf.function\n",
        "def train_step(inputs, labels):\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(inputs, training=True)\n",
        "    regularization_loss = tf.math.add_n(model.losses)\n",
        "    pred_loss = loss_fn(labels, predictions)\n",
        "    total_loss = pred_loss + regularization_loss\n",
        "\n",
        "  gradients = tape.gradient(total_loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "  for inputs, labels in train_data:\n",
        "    train_step(inputs, labels)\n",
        "  print(\"Finished epoch\", epoch)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WikxMFGgo3oZ"
      },
      "source": [
        "### Take advantage of `tf.function` with Python control flow\n",
        "\n",
        "`tf.function` provides a way to convert data-dependent control flow into graph-mode\n",
        "equivalents like `tf.cond` and `tf.while_loop`.\n",
        "\n",
        "One common place where data-dependent control flow appears is in sequence\n",
        "models. `tf.keras.layers.RNN` wraps an RNN cell, allowing you to either\n",
        "statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n5UebfChRu4T"
      },
      "outputs": [],
      "source": [
        "class DynamicRNN(tf.keras.Model):\n",
        "\n",
        "  def __init__(self, rnn_cell):\n",
        "    super().__init__()\n",
        "    self.cell = rnn_cell\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])\n",
        "  def call(self, input_data):\n",
        "\n",
        "    # [batch, time, features] -> [time, batch, features]\n",
        "    input_data = tf.transpose(input_data, [1, 0, 2])\n",
        "    timesteps = tf.shape(input_data)[0]\n",
        "    batch_size = tf.shape(input_data)[1]\n",
        "    outputs = tf.TensorArray(tf.float32, timesteps)\n",
        "    state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)\n",
        "    for i in tf.range(timesteps):\n",
        "      output, state = self.cell(input_data[i], state)\n",
        "      outputs = outputs.write(i, output)\n",
        "    return tf.transpose(outputs.stack(), [1, 0, 2]), state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NhBI_SGKQVIB"
      },
      "outputs": [],
      "source": [
        "lstm_cell = tf.keras.layers.LSTMCell(units = 13)\n",
        "\n",
        "my_rnn = DynamicRNN(lstm_cell)\n",
        "outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))\n",
        "print(outputs.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "du7bn3NX7iIr"
      },
      "source": [
        "Read the [`tf.function` guide](https://www.tensorflow.org/guide/function) for more information."
      ]
    },
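    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the conversion of data-dependent control flow more concrete, here is a small sketch that is unrelated to RNNs. The function `collatz_steps` is a hypothetical example: because `n` is a tensor, AutoGraph turns the Python `while` and `if` statements into graph loops and conditionals.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "@tf.function\n",
        "def collatz_steps(n):\n",
        "  # Both the loop condition and the branch condition depend on the value\n",
        "  # of the tensor `n`, so they cannot be resolved while tracing.\n",
        "  steps = tf.constant(0)\n",
        "  while n > 1:\n",
        "    if n % 2 == 0:\n",
        "      n = n // 2\n",
        "    else:\n",
        "      n = 3 * n + 1\n",
        "    steps += 1\n",
        "  return steps\n",
        "\n",
        "print(collatz_steps(tf.constant(6)).numpy())  # 8\n",
        "```"
      ]
    },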
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SUAYhgL_NomT"
      },
      "source": [
        "### New-style metrics and losses\n",
        "\n",
        "Metrics and losses are both objects that work eagerly and in `tf.function`s.\n",
        "\n",
        "A loss object is callable, and expects (`y_true`, `y_pred`) as arguments:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pf5gcwMzNs8F"
      },
      "outputs": [],
      "source": [
        "cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
        "cce([[1, 0]], [[-1.0,3.0]]).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a89m-wRfxyfV"
      },
      "source": [
        "#### Use metrics to collect and display data\n",
        "\n",
        "You can use `tf.metrics` to aggregate data and `tf.summary` to log summaries and redirect them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the `step` value at the callsite.\n",
        "\n",
        "```python\n",
        "summary_writer = tf.summary.create_file_writer('/tmp/summaries')\n",
        "with summary_writer.as_default():\n",
        "  tf.summary.scalar('loss', 0.1, step=42)\n",
        "```\n",
        "\n",
        "Use `tf.metrics` to aggregate data before logging it as summaries. Metrics are stateful; they accumulate values and return a cumulative result when you call the `result` method (such as `Mean.result`). Clear accumulated values with `Metric.reset_states`.\n",
        "\n",
        "```python\n",
        "def train(model, optimizer, dataset, log_freq=10):\n",
        "  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)\n",
        "  for images, labels in dataset:\n",
        "    loss = train_step(model, optimizer, images, labels)\n",
        "    avg_loss.update_state(loss)\n",
        "    if tf.equal(optimizer.iterations % log_freq, 0):\n",
        "      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)\n",
        "      avg_loss.reset_states()\n",
        "\n",
        "def test(model, test_x, test_y, step_num):\n",
        "  # training=False is only needed if there are layers with different\n",
        "  # behavior during training versus inference (e.g. Dropout).\n",
        "  loss = loss_fn(test_y, model(test_x, training=False))\n",
        "  tf.summary.scalar('loss', loss, step=step_num)\n",
        "\n",
        "train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')\n",
        "test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')\n",
        "\n",
        "with train_summary_writer.as_default():\n",
        "  train(model, optimizer, dataset)\n",
        "\n",
        "with test_summary_writer.as_default():\n",
        "  test(model, test_x, test_y, optimizer.iterations)\n",
        "```\n",
        "\n",
        "Visualize the generated summaries by pointing TensorBoard to the summary log\n",
        "directory:\n",
        "\n",
        "```shell\n",
        "tensorboard --logdir /tmp/summaries\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0tx7FyM_RHwJ"
      },
      "source": [
        "Use the `tf.summary` API to write summary data for visualization in TensorBoard. For more info, read the [`tf.summary` guide](https://www.tensorflow.org/tensorboard/migrate#in_tf_2x)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HAbA0fKW58CH"
      },
      "outputs": [],
      "source": [
        "# Create the metrics\n",
        "loss_metric = tf.keras.metrics.Mean(name='train_loss')\n",
        "accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
        "\n",
        "@tf.function\n",
        "def train_step(inputs, labels):\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(inputs, training=True)\n",
        "    regularization_loss = tf.math.add_n(model.losses)\n",
        "    pred_loss = loss_fn(labels, predictions)\n",
        "    total_loss = pred_loss + regularization_loss\n",
        "\n",
        "  gradients = tape.gradient(total_loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "  # Update the metrics\n",
        "  loss_metric.update_state(total_loss)\n",
        "  accuracy_metric.update_state(labels, predictions)\n",
        "\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "  # Reset the metrics\n",
        "  loss_metric.reset_states()\n",
        "  accuracy_metric.reset_states()\n",
        "\n",
        "  for inputs, labels in train_data:\n",
        "    train_step(inputs, labels)\n",
        "  # Get the metric results\n",
        "  mean_loss = loss_metric.result()\n",
        "  mean_accuracy = accuracy_metric.result()\n",
        "\n",
        "  print('Epoch: ', epoch)\n",
        "  print('  loss:     {:.3f}'.format(mean_loss))\n",
        "  print('  accuracy: {:.3f}'.format(mean_accuracy))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bG9AaMzih3eh"
      },
      "source": [
        "#### Keras metric names\n",
        "<a name=\"keras_metric_names\"></a>\n",
        "\n",
        "Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that _exact_ string is used as the metric's `name`. These names are visible in the history object returned by `model.fit`, and in the logs passed to `keras.callbacks`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1iODIsGDgyYd"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer = tf.keras.optimizers.Adam(0.001),\n",
        "    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name=\"my_accuracy\")])\n",
        "history = model.fit(train_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8oGzs_TlisKJ"
      },
      "outputs": [],
      "source": [
        "history.history.keys()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JaB2z2XIyhcr"
      },
      "source": [
        "### Debugging\n",
        "\n",
        "Use eager execution to run your code step-by-step to inspect shapes, data types, and values. Certain APIs, such as `tf.function` and `tf.keras`,\n",
        "are designed to use graph execution for performance and portability. When\n",
        "debugging, use `tf.config.run_functions_eagerly(True)` to use eager execution\n",
        "inside this code.\n",
        "\n",
        "For example:\n",
        "\n",
        "```python\n",
        "@tf.function\n",
        "def f(x):\n",
        "  if x > 0:\n",
        "    import pdb\n",
        "    pdb.set_trace()\n",
        "    x = x + 1\n",
        "  return x\n",
        "\n",
        "tf.config.run_functions_eagerly(True)\n",
        "f(tf.constant(1))\n",
        "```\n",
        "\n",
        "```\n",
        ">>> f()\n",
        "-> x = x + 1\n",
        "(Pdb) l\n",
        "  6     @tf.function\n",
        "  7     def f(x):\n",
        "  8       if x > 0:\n",
        "  9         import pdb\n",
        " 10         pdb.set_trace()\n",
        " 11  ->     x = x + 1\n",
        " 12       return x\n",
        " 13\n",
        " 14     tf.config.run_functions_eagerly(True)\n",
        " 15     f(tf.constant(1))\n",
        "[EOF]\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gdvGF2FvbBXZ"
      },
      "source": [
        "This also works inside Keras models and other APIs that support eager execution:\n",
        "\n",
        "```python\n",
        "class CustomModel(tf.keras.models.Model):\n",
        "\n",
        "  @tf.function\n",
        "  def call(self, input_data):\n",
        "    if tf.reduce_mean(input_data) > 0:\n",
        "      return input_data\n",
        "    else:\n",
        "      import pdb\n",
        "      pdb.set_trace()\n",
        "      return input_data // 2\n",
        "\n",
        "\n",
        "tf.config.run_functions_eagerly(True)\n",
        "model = CustomModel()\n",
        "model(tf.constant([-2, -4]))\n",
        "```\n",
        "\n",
        "```\n",
        ">>> call()\n",
        "-> return input_data // 2\n",
        "(Pdb) l\n",
        " 10         if tf.reduce_mean(input_data) > 0:\n",
        " 11           return input_data\n",
        " 12         else:\n",
        " 13           import pdb\n",
        " 14           pdb.set_trace()\n",
        " 15  ->       return input_data // 2\n",
        " 16\n",
        " 17\n",
        " 18     tf.config.run_functions_eagerly(True)\n",
        " 19     model = CustomModel()\n",
        " 20     model(tf.constant([-2, -4]))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S0-F-bvJXKD8"
      },
      "source": [
        "Notes:\n",
        "* `tf.keras.Model` methods such as `fit`, `evaluate`, and `predict` execute as [graphs](https://www.tensorflow.org/guide/intro_to_graphs) with `tf.function` under the hood.\n",
        "\n",
        "* When using `tf.keras.Model.compile`, set `run_eagerly=True` to prevent the `Model` logic from being wrapped in a `tf.function`.\n",
        "\n",
        "* Use `tf.data.experimental.enable_debug_mode` to enable the debug mode for `tf.data`. Read the [API docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/enable_debug_mode) for more details.\n"
      ]
    },
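    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the `run_eagerly` note above, the snippet below compiles a tiny model so that `fit` executes its training step eagerly, which lets you step through it with a debugger. The model, the data, and the name `debug_model` are placeholders chosen only for illustration.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "debug_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])\n",
        "\n",
        "# run_eagerly=True keeps the training step out of a tf.function.\n",
        "debug_model.compile(optimizer='sgd', loss='mse', run_eagerly=True)\n",
        "debug_model.fit(tf.ones([8, 3]), tf.ones([8, 1]), epochs=1, verbose=0)\n",
        "```\n",
        "\n",
        "Eager execution is slower, so this setting is best limited to debugging sessions."
      ]
    },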
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wxa5yKK7bym0"
      },
      "source": [
        "### Do not keep `tf.Tensor`s in your objects\n",
        "\n",
        "These tensor objects might get created either in a `tf.function` or in the eager context, and tensors from the two contexts behave differently. Use `tf.Tensor`s only for intermediate values.\n",
        "\n",
        "To track state, use `tf.Variable`s as they are always usable from both contexts. Read the [`tf.Variable` guide](https://www.tensorflow.org/guide/variable) to learn more.\n"
      ]
    },
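    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sketch below illustrates this advice with a hypothetical `Counter` module. Its state lives in a `tf.Variable`, so the same object can be updated from inside a `tf.function` and from eager code.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "class Counter(tf.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    # State is stored in a tf.Variable, not in a tf.Tensor attribute.\n",
        "    self.count = tf.Variable(0, dtype=tf.int32)\n",
        "\n",
        "  @tf.function\n",
        "  def increment(self, by):\n",
        "    # `by` is an intermediate tensor; only the variable persists.\n",
        "    self.count.assign_add(by)\n",
        "    return self.count.read_value()\n",
        "\n",
        "counter = Counter()\n",
        "counter.increment(tf.constant(2))  # Runs as a graph.\n",
        "counter.count.assign_add(3)        # Runs eagerly.\n",
        "print(counter.count.numpy())  # 5\n",
        "```"
      ]
    },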
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FdXLLYa2uAyx"
      },
      "source": [
        "## Resources and further reading\n",
        "\n",
        "* Read the TF2 [guides](https://tensorflow.org/guide) and [tutorials](https://tensorflow.org/tutorials) to learn more about how to use TF2.\n",
        "\n",
        "* If you previously used TF1.x, it is highly recommended you migrate your code to TF2. Read the [migration guides](https://tensorflow.org/guide/migrate) to learn more."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "effective_tf2.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
